#!/usr/bin/env python3
#
# Python script to run and analyse the FCI MMS convergence test
#
from __future__ import division
from __future__ import print_function

import pickle
import sys
from sys import stdout

from numpy import array, log, polyfit, linspace, arange

from boututils.run_wrapper import shell_safe, launch_safe
from boutdata.collect import collect
import zoidberg as zb

# ---- Test configuration ----

# Radial (x) resolution, fixed for every run in these tests
nx = 3

# Resolutions used in both y and z, in increasing order
nlist = [8, 16, 32, 64, 128]

# Number of parallel slices (in each direction); each value is a separate
# convergence test
nslices = [1, 2]

# Run with periodic Y?
yperiodic = True

# Directory the l_2 / l_inf diagnostics are collected from
directory = "data"

# Values passed to launch_safe as nproc and mthread
nproc, mthread = 2, 2

# ---- Result containers, keyed by nslice ----
error_2, error_inf, method_orders = {}, {}, {}

# ---- Overall status ----
success = True
failures = []  # names of schemes whose convergence check failed

# Build with make, redirecting the build output to make.log
# (presumably produces the ./fci_mms executable launched below)
build_command = "make > make.log"
print("Making fci MMS test")
shell_safe(build_command)

for nslice in nslices:
    # For this number of parallel slices: run the MMS case at every
    # resolution in nlist, record the l-2 / l-inf errors, then check that
    # the measured convergence order of the l-2 error reaches (95% of) the
    # expected order of the central difference scheme.
    error_2[nslice] = []
    error_inf[nslice] = []

    # Which central difference scheme to use and its expected order
    expected_order = nslice * 2
    method_orders[nslice] = {
        "name": "C{}".format(expected_order),
        "order": expected_order
    }

    for n in nlist:
        # Define the magnetic field using new poloidal gridding method
        # Note that the Bz and Bzprime parameters here must be the same as in mms.py
        field = zb.field.Slab(Bz=0.05, Bzprime=0.1)
        # Create rectangular poloidal grids
        poloidal_grid = zb.poloidal_grid.RectangularPoloidalGrid(nx, n, 0.1, 1.)
        # Set the ylength and y locations
        ylength = 10.

        if yperiodic:
            ycoords = linspace(0.0, ylength, n, endpoint=False)
        else:
            # Doesn't include the end points
            ycoords = (arange(n) + 0.5) * ylength / float(n)

        # Create the grid, then make and write the maps for this nslice
        grid = zb.grid.Grid(poloidal_grid, ycoords, ylength, yperiodic=yperiodic)
        maps = zb.make_maps(grid, field, nslice=nslice, quiet=True)
        zb.write_maps(grid, field, maps, new_names=False, metric2d=True, quiet=True)

        # MYG is set to nslice; mesh:ddy:first selects the C<2*nslice> scheme
        args = (" MZ={} MYG={} fci:y_periodic={} mesh:ddy:first={}"
                .format(n, nslice, yperiodic, method_orders[nslice]["name"]))

        # Command to run
        cmd = "./fci_mms " + args
        print("Running command: " + cmd)

        # Launch using MPI
        status, out = launch_safe(cmd, nproc=nproc, mthread=mthread, pipe=True)

        # Save output to a log file. The name includes nslice as well as n:
        # keyed on n alone, the runs for later nslice values overwrote the
        # logs of earlier ones at the same resolution.
        with open("run.log.{}.{}".format(nslice, n), "w") as f:
            f.write(out)

        if status:
            print("Run failed!\nOutput was:\n")
            print(out)
            # sys.exit, not the interactive-only site builtin exit()
            sys.exit(status)

        # Collect the error norms written by the run
        l_2 = collect("l_2", tind=[1, 1], info=False,
                      path=directory, xguards=False, yguards=False)
        l_inf = collect("l_inf", tind=[1, 1], info=False,
                        path=directory, xguards=False, yguards=False)

        error_2[nslice].append(l_2)
        error_inf[nslice].append(l_inf)

        print("Errors : l-2 {:f} l-inf {:f}".format(l_2, l_inf))

    # Mesh spacing for each resolution
    dx = 1. / array(nlist)

    # Convergence order from a least-squares line through log(error) vs log(dx)
    fit = polyfit(log(dx), log(error_2[nslice]), 1)
    fit_order = fit[0]
    stdout.write("Convergence order = {:f} (fit)".format(fit_order))

    # Convergence order between the two finest resolutions; this is the
    # value the pass/fail check uses
    small_spacing_order = (log(error_2[nslice][-2] / error_2[nslice][-1])
                           / log(dx[-2] / dx[-1]))
    stdout.write(", {:f} (small spacing)".format(small_spacing_order))

    # Should be close to the expected order (orders above it also pass)
    if small_spacing_order > expected_order * 0.95:
        print("............ PASS\n")
    else:
        print("............ FAIL\n")
        success = False
        failures.append(method_orders[nslice]["name"])


# Save results as a stream of consecutive pickles: first nlist, then for
# each nslice (in nslices order) its l-2 errors followed by its l-inf errors.
# Readers must call pickle.load once per record, in the same order.
with open("fci_mms.pkl", "wb") as pkl_file:
    records = [nlist]
    for slices in nslices:
        records.extend((error_2[slices], error_inf[slices]))
    for record in records:
        pickle.dump(record, pkl_file)

# Do we want to show the plot as well as save it to file.
showPlot = True

# Plotting is disabled; set to True to write the error-scaling plot
# to fci_mms.pdf
makePlot = False

if makePlot:
    try:
        # Plot using matplotlib if available
        import matplotlib.pyplot as plt
    except ImportError:
        print("No matplotlib")
    else:
        fig, ax = plt.subplots(1, 1)

        # One solid (l-2) and one dashed (l-inf) error curve per scheme,
        # plotted against the mesh spacing dx
        for nslice in nslices:
            name = method_orders[nslice]["name"]
            ax.plot(dx, error_2[nslice], '-',
                    label="{} $l_2$".format(name))
            # Raw string: a backslash-i in a plain literal is an invalid
            # escape sequence (DeprecationWarning; SyntaxWarning on 3.12+)
            ax.plot(dx, error_inf[nslice], '--',
                    label=r"{} $l_\inf$".format(name))
        ax.legend(loc="upper left")
        ax.grid()
        ax.set_yscale('log')
        ax.set_xscale('log')
        ax.set_title('error scaling')
        ax.set_xlabel(r'Mesh spacing $\delta x$')
        ax.set_ylabel("Error norm")

        plt.savefig("fci_mms.pdf")

        print("Plot saved to fci_mms.pdf")

        if showPlot:
            plt.show()
        plt.close()

# Report the overall result. The process exit code is 0 only if every
# convergence test passed, otherwise 1. sys.exit is used because the
# builtin exit() is added by the site module for interactive use only.
if success:
    print("All tests passed")
    sys.exit(0)
else:
    print("Some tests failed:")
    for failure in failures:
        print("\t" + failure)
    sys.exit(1)
